# Lightning early-warning MLP
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

class TlpMlp(nn.Module):
    """Fully connected network (MLP) for the lightning early-warning model.

    The network is a stack of ``Linear -> ReLU`` pairs followed by a final
    ``Linear`` layer with no activation, so the output is raw logits.

    With the default arguments the architecture is
    ``300 -> 128 -> 32 -> 5``, and the parameter names inside ``self.fns``
    (``fns.0``, ``fns.2``, ``fns.4``) are the same as in the fixed-size
    version, so existing checkpoints remain loadable.

    Args:
        in_features: Size of the last dimension of the input tensor.
        hidden_dims: Widths of the hidden layers, in order. Each hidden
            layer is followed by a ReLU. May be empty, in which case the
            model is a single linear layer.
        out_features: Size of the last dimension of the output logits.
    """

    def __init__(self, in_features=300, hidden_dims=(128, 32), out_features=5):
        super().__init__()

        layers = []
        prev_width = in_features
        # Layers are created in input-to-output order so that parameter
        # initialisation consumes the RNG in the same sequence as before.
        for width in hidden_dims:
            layers.append(nn.Linear(prev_width, width))
            layers.append(nn.ReLU())
            prev_width = width
        # Output layer: no activation, the caller receives raw logits.
        layers.append(nn.Linear(prev_width, out_features))

        self.fns = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the logits.

        Args:
            x: Tensor whose last dimension equals ``in_features``
                (300 by default).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features`` (5 by default).
        """
        return self.fns(x)